import random
import os
import torch.optim as optim
import matplotlib.pyplot as plt
import torch.utils.data
from torch import nn
import torch.autograd as autograd
import numpy as np
import time
from model import CNNModel
import utils
import math
import argparse
from data_loader import GetLoader
from torchvision import datasets
from torchvision import transforms
from model import CNNModel
from AGD import *
from AGD1 import *

# parse arguments
#
# Every option is optional; each row below is (flag, type, default).
#   --M            coefficient of the cubic term in the cubic sub-problem
#   --lr_x         parsed and stored; not read by the code visible in this script
#   --lr_y         learning rate of the AGD1 optimizer on the domain head
#   --lr_cubic     step size of the gradient-descent cubic sub-solver
#   --lam          regularisation weight on the domain-head weights
#   --feature_size width of the feature vector passed to CNNModel / the linear heads
#   --epoch        number of outer iterations
#   --cubicepoch   maximum number of cubic sub-solver iterations
#   --agdepoch     number of inner AGD steps per outer iteration
#   --chebepoch    upper cap on the per-iteration `chebepoch` value
#   --cauchypoint  gradient-norm threshold selecting the Cauchy-point step
#   --step         logging interval (in outer iterations)
_ARGUMENT_TABLE = (
    ('--M', float, 1),
    ('--lr_x', float, 1e-2),
    ('--lr_y', float, 1e-2),
    ('--lr_cubic', float, 1e-2),
    ('--lam', float, 1),
    ('--feature_size', int, 100),
    ('--epoch', int, 10000),
    ('--cubicepoch', int, 300),
    ('--agdepoch', int, 300),
    ('--chebepoch', int, 50),
    ('--cauchypoint', float, 1.0),
    ('--step', int, 1),
)

parser = argparse.ArgumentParser()
for _flag, _kind, _default in _ARGUMENT_TABLE:
    parser.add_argument(_flag, type=_kind, default=_default)

args = parser.parse_args()

# Short module-level aliases used throughout the rest of the script.
M, lam = args.M, args.lam
lr_x, lr_y, lr_cubic = args.lr_x, args.lr_y, args.lr_cubic
feature_size = args.feature_size
n_epoch, cubic_epoch, agd_maxiter = args.epoch, args.cubicepoch, args.agdepoch
cauchypoint = args.cauchypoint
step = args.step
cheb = args.chebepoch

# Query CUDA availability once; `cuda` is the boolean flag and `device` the
# matching torch device string.
cuda = torch.cuda.is_available()
device = 'cuda' if cuda else 'cpu'

# Dataset names and the directories they live in (relative to the working dir).
source_dataset_name = 'MNIST'
target_dataset_name = 'mnist_m'
source_image_root = os.path.join('dataset', source_dataset_name)
target_image_root = os.path.join('dataset', target_dataset_name)

# Side length that both image transforms resize to.
image_size = 28

# Fixed seed for Python's and PyTorch's random number generators.
manual_seed = 8
random.seed(manual_seed)
torch.manual_seed(manual_seed)


# load data


def _build_transform(mean, std):
    """Return the resize -> tensor -> normalise pipeline for the given stats."""
    return transforms.Compose([
        transforms.Resize(image_size),
        transforms.ToTensor(),
        transforms.Normalize(mean=mean, std=std),
    ])


# Source statistics are single-channel; target statistics are per RGB channel.
img_transform_source = _build_transform(mean=(0.1307,), std=(0.3081,))
img_transform_target = _build_transform(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5))

# Source domain: torchvision MNIST training split, downloaded on demand.
dataset_source = datasets.MNIST(
    root='dataset', train=True, transform=img_transform_source, download=True)

# The batch size equals the dataset size, so a single batch holds every sample.
dataloader_source = torch.utils.data.DataLoader(
    dataset=dataset_source, batch_size=len(dataset_source), num_workers=8)

# Target domain: images under mnist_m_train, listed in the labels text file.
train_list = os.path.join(target_image_root, 'mnist_m_train_labels.txt')
dataset_target = GetLoader(
    data_root=os.path.join(target_image_root, 'mnist_m_train'),
    data_list=train_list,
    transform=img_transform_target,
)

# Same full-dataset batching as the source loader, but shuffled.
dataloader_target = torch.utils.data.DataLoader(
    dataset=dataset_target,
    batch_size=len(dataset_target),
    shuffle=True,
    num_workers=8,
)

# Iterators are created source first, then target.
data_source_iter = iter(dataloader_source)
data_target_iter = iter(dataloader_target)

# load model

# Network plus two bias-free linear heads mapping features to one logit, all
# converted to float64. `my_logi` is the head trained inside the main loop;
# `my_logi_phi` is a second head that the loop re-initialises from `my_logi`
# when it evaluates the logged objective.
my_net = CNNModel(feature_size).double()
my_logi = torch.nn.Linear(feature_size, 1, bias=False).double()
my_logi_phi = torch.nn.Linear(feature_size, 1, bias=False).double()

# setup optimizer
# Alternatives previously tried (kept for reference):
#   logi_optimizer     = optim.SGD(my_logi.parameters(), lr=lr_y, momentum=0.99, nesterov=True)
#   logi_phi_optimizer = AGD(my_logi_phi.parameters(), lr=lr_y, theta=0.9)
logi_optimizer = AGD1(my_logi.parameters(), lr=lr_y, theta=0.9, weight_decay=lam)
logi_phi_optimizer = optim.SGD(
    my_logi_phi.parameters(), lr=0.001, momentum=0.99, nesterov=True)

# Class loss consumes log-probabilities; domain loss consumes raw logits.
loss_class = nn.NLLLoss()
loss_domain = nn.BCEWithLogitsLoss()

if cuda:
    # nn.Module.cuda() moves the module in place and returns the same object,
    # so the module-level names keep pointing at the moved modules.
    for _module in (my_net, my_logi, my_logi_phi, loss_class, loss_domain):
        _module.cuda()

# Make sure every parameter of the three trainable modules tracks gradients.
for _module in (my_net, my_logi, my_logi_phi):
    for p in _module.parameters():
        p.requires_grad = True

print('Load model done!')


# load source data
# Both loaders were built with batch_size == len(dataset), so one draw from
# each iterator yields the whole dataset as a single batch.
print('Loading source data...')
# Use the builtin next(); iterators have no .next() method in Python 3.
data_source = next(data_source_iter)
s_img, s_label = data_source
# Flatten each 28x28 image to a 784-vector and match the float64 model.
s_img = s_img.view(-1, 28 * 28)
s_img = s_img.double()
print('Finish loading source data.')

# load target data
print('Loading target data...')
data_target = next(data_target_iter)
t_img, t_label = data_target
# RGB -> grayscale with luma weights 0.299 R + 0.587 G + 0.114 B.
# The blue term must read channel 2 (it previously re-used the green channel).
t_img = 0.299 * t_img[:, 0, :, :] + 0.587 * t_img[:, 1, :, :] + 0.114 * t_img[:, 2, :, :]
t_img = t_img.view(-1, 28 * 28)
t_img = t_img.double()
print('Finish loading target data.')

# Domain labels: 0 for source samples, 1 for target samples. They are created
# in float64 so the BCE-with-logits target has the same dtype as the float64
# logits produced by the double-precision heads.
s_domain_label = torch.zeros(len(s_label), dtype=torch.double)
t_domain_label = torch.ones(len(t_img), dtype=torch.double)

if cuda:
    s_img = s_img.cuda()
    t_img = t_img.cuda()
    s_label = s_label.cuda()
    t_label = t_label.cuda()
    s_domain_label = s_domain_label.cuda()
    t_domain_label = t_domain_label.cuda()

# Sample counts, used to scale the per-sample Hessian weights later on.
s_data_num = s_img.shape[0]
t_data_num = t_img.shape[0]

# Histories that are written to the .npz file at the end of the run.
P_result = []
x_grad_result = []
y_grad_result = []
oracle_result = []
# The time history starts with a 0.0 entry for the initial point.
time_result = [0.0]

# Running totals: wall-clock seconds spent optimising and oracle-call count.
total_time = 0
total_oracle = 0


# Warm-up pass: one forward/backward over the full data, used here only to
# obtain the sizes and constant tensors needed by the training loop.
s_fea, class_output = my_net(input_data=s_img)
t_fea, _ = my_net(input_data=t_img)

err_s_label = loss_class(class_output, s_label)

s_domain_output = my_logi(s_fea)
t_domain_output = my_logi(t_fea)

err_s_domain = loss_domain(s_domain_output.squeeze(), s_domain_label)
err_t_domain = loss_domain(t_domain_output.squeeze(), t_domain_label)

err = err_s_label - err_t_domain - err_s_domain
grad_x = autograd.grad(err, my_net.parameters())

# Number of scalar entries in the first two gradient tensors
# (presumably the first layer's weight and bias - confirm against CNNModel).
ll = nn.utils.parameters_to_vector(grad_x[:2]).shape[0]

# Total number of scalar network parameters.
flat_grad_x = nn.utils.parameters_to_vector(grad_x)
x_num = flat_grad_x.shape[0]

# Identity of size feature_dim x feature_dim, added (scaled by lam) to the
# domain-head Hessian inside the loop. NOTE: this name shadows the builtin `id`.
id = torch.eye(s_fea.shape[1], dtype=torch.double, device=device)

# Zero vector covering the remaining x_num - ll entries; it is handed to
# utils.hessian_vector_product2 on every call.
supp = torch.zeros(x_num - ll, dtype=torch.double, device=device)

my_net.zero_grad()




# training
#
# Each outer iteration:
#   1. runs `agd_maxiter` AGD1 steps on the domain head `my_logi` with the
#      network features held fixed,
#   2. evaluates the objective and its gradients w.r.t. the network (x) and
#      the domain head (y),
#   3. builds the domain-head Hessian in closed form,
#   4. computes an update `s` for the network parameters, either from a
#      Cauchy-point formula or from gradient descent on the cubic model,
#   5. applies `s`, and every `step` iterations re-solves the inner problem
#      with `my_logi_phi` to log the objective value P.
# `total_time` only accumulates the optimisation segments; the logging work
# between an `end_time` stamp and the following `start_time` stamp is excluded.

start_time = time.time()


for epoch in range(n_epoch):

    # Perform AGD on Y
    # Features are computed without autograd: only my_logi is trained here.
    with torch.no_grad():
        s_fea, _ = my_net(input_data=s_img)
        t_fea, _ = my_net(input_data=t_img)

    for i in range(agd_maxiter):
        logi_optimizer.zero_grad()

        s_domain_output = my_logi(s_fea)
        t_domain_output = my_logi(t_fea)

        err_s_domain = loss_domain(s_domain_output.squeeze(), s_domain_label)
        err_t_domain = loss_domain(t_domain_output.squeeze(), t_domain_label)

        # No explicit lam/2 * ||w||^2 term here: logi_optimizer was built with
        # weight_decay=lam (presumably AGD1 applies the regulariser that way -
        # confirm against AGD1).
        err_domain = err_s_domain + err_t_domain

        err_domain.backward()
        logi_optimizer.step()

    total_oracle += agd_maxiter

    ## AGD finished

    my_net.zero_grad()
    my_logi.zero_grad()

    ## Compute err
    # Full objective: class loss minus both domain losses minus the
    # regulariser on the domain-head weights.
    s_fea, class_output = my_net(input_data=s_img)
    t_fea, _ = my_net(input_data=t_img)

    err_s_label = loss_class(class_output, s_label)

    s_domain_output = my_logi(s_fea)
    t_domain_output = my_logi(t_fea)

    err_s_domain = loss_domain(s_domain_output.squeeze(), s_domain_label)
    err_t_domain = loss_domain(t_domain_output.squeeze(), t_domain_label)

    err = err_s_label - err_t_domain - err_s_domain - lam / 2 * (my_logi.weight.norm() ** 2)

    # create_graph=True keeps the graph of the gradients so that
    # utils.hessian_vector_product2 can differentiate through them.
    grad_x = autograd.grad(err, my_net.parameters(), create_graph=True)
    grad_y = autograd.grad(err, my_logi.weight, create_graph=True)

    flat_grad_x = nn.utils.parameters_to_vector(grad_x)
    flat_grad_y = nn.utils.parameters_to_vector(grad_y)
    if (epoch + 1) % step == 0:
        # Pause the timer while the gradient norms are recorded.
        end_time = time.time()
        total_time += end_time - start_time

        x_grad = flat_grad_x.norm().detach().cpu().numpy()
        y_grad = flat_grad_y.norm().detach().cpu().numpy()
        x_grad_result.append(x_grad)
        y_grad_result.append(y_grad)

        start_time = time.time()

    # Compute Hessian
    # For the linear head with a logistic loss the Hessian w.r.t. the weights
    # is F^T diag(D) F with D = sigmoid * (1 - sigmoid) / n per sample, summed
    # over both domains, plus lam * I from the regulariser.
    with torch.no_grad():
        s_sig = torch.sigmoid(my_logi(s_fea))
        t_sig = torch.sigmoid(my_logi(t_fea))
        s_D = (s_sig * (1 - s_sig)).squeeze() / s_data_num
        t_D = (t_sig * (1 - t_sig)).squeeze() / t_data_num

        hess = (s_fea.T * s_D) @ s_fea + (t_fea.T * t_D) @ t_fea + lam * id

        # Estimates of the largest / smallest Hessian eigenvalue.
        l_est = s_fea.norm() ** 2 * torch.max(s_D) + t_fea.norm() ** 2 * torch.max(t_D) + lam
        mu_est = lam

    # end Hessian
    # Iteration count handed to hessian_vector_product2: ceil(sqrt(l/mu)),
    # clamped to the range [5, cheb].
    chebepoch = min(max(math.ceil(math.sqrt(l_est / mu_est)), 5), cheb)

    ## Cubic subsolver
    gnorm = flat_grad_x.norm()
    # zeros_like already allocates on flat_grad_x's device, so this also works
    # on CPU-only hosts (a hard-coded .cuda() would raise there).
    s = torch.zeros_like(flat_grad_x)

    if gnorm >= cauchypoint:
        # Large gradient: take a closed-form step along the negative
        # normalised gradient.
        # NOTE(review): `step` (the logging interval) is reassigned in both
        # branches, so the interval used by the check further down can differ
        # from the one used for the gradient-norm logging above.
        step = args.step
        total_oracle += chebepoch

        n_flat_grad_x = flat_grad_x / gnorm
        Hg = utils.hessian_vector_product2(supp, flat_grad_x, flat_grad_y, my_net, my_logi.weight, hess, n_flat_grad_x, l_est, mu_est, chebepoch)
        with torch.no_grad():
            gHg = (n_flat_grad_x * Hg).sum()
            tmp = gHg / M
            #tmp = gHg / (M * (gnorm ** 2))
            Rc = -tmp + torch.sqrt(tmp ** 2 + 2 * gnorm / M)
            s = -Rc * n_flat_grad_x

    else:
        # Small gradient: run gradient descent on the cubic model
        # g^T s + 1/2 s^T H s + M/6 ||s||^3, whose gradient is
        # g + H s + M/2 ||s|| s.
        step = 2

        for i in range(cubic_epoch):

            Hs = utils.hessian_vector_product2(supp, flat_grad_x, flat_grad_y, my_net, my_logi.weight, hess, s, l_est, mu_est, chebepoch)

            with torch.no_grad():
                cubic_grad = flat_grad_x + Hs + M / 2 * s.norm() * s
                s = s - lr_cubic * cubic_grad

            if cubic_grad.norm() < 1e-4:
                #print('cubic break iteration =', i)
                break

        # NOTE(review): the full cubic_epoch budget is charged even when the
        # loop above exits early.
        total_oracle += cubic_epoch * chebepoch

    # Apply the step to the flattened network parameters.
    with torch.no_grad():

        flat_x = nn.utils.parameters_to_vector(my_net.parameters())
        flat_x += s
        nn.utils.vector_to_parameters(flat_x, my_net.parameters())

    if (epoch + 1) % step == 0:

        end_time = time.time()
        total_time += end_time - start_time
        time_result.append(total_time)

        ## Perform AGD
        # Re-solve the inner maximisation from the current my_logi weights on a
        # separate head so that the training state of my_logi is untouched.
        my_logi_phi.load_state_dict(my_logi.state_dict())

        with torch.no_grad():
            s_fea_phi, _ = my_net(input_data=s_img)
            t_fea_phi, _ = my_net(input_data=t_img)

        for i in range(1000):
            logi_phi_optimizer.zero_grad()

            s_domain_output_phi = my_logi_phi(s_fea_phi)
            t_domain_output_phi = my_logi_phi(t_fea_phi)

            err_s_domain_phi = loss_domain(s_domain_output_phi.squeeze(), s_domain_label)
            err_t_domain_phi = loss_domain(t_domain_output_phi.squeeze(), t_domain_label)

            # The regulariser must be taken on the weights being optimised
            # (my_logi_phi); using my_logi here contributes no gradient to
            # my_logi_phi and leaves the inner problem unregularised.
            err_domain = err_s_domain_phi + err_t_domain_phi + lam / 2 * (my_logi_phi.weight.norm() ** 2)

            err_domain.backward()
            logi_phi_optimizer.step()

        #print('agd grad = ', my_logi_phi.weight.grad.norm().cpu().numpy())

        ## Compute P function
        # P is the objective evaluated at the re-solved head my_logi_phi, so
        # both the domain losses and the regulariser use my_logi_phi.
        with torch.no_grad():
            s_fea, class_output = my_net(input_data=s_img)
            err_s_label = loss_class(class_output, s_label)

            s_domain_output = my_logi_phi(s_fea)
            err_s_domain = loss_domain(s_domain_output.squeeze(), s_domain_label)

            t_fea, _ = my_net(input_data=t_img)
            t_domain_output = my_logi_phi(t_fea)
            err_t_domain = loss_domain(t_domain_output.squeeze(), t_domain_label)

            err = err_s_label - err_t_domain - err_s_domain - lam / 2 * (my_logi_phi.weight.norm() ** 2)
            err = err.detach().cpu().numpy()
            P_result.append(err)
            oracle_result.append(total_oracle)

        print('Epoch =', epoch, 'agd grad = ', my_logi_phi.weight.grad.norm().cpu().numpy(), 'x_grad =', x_grad,
              'y_grad =', y_grad, 'P =', err, 'Time =', total_time, 'Oracle =', total_oracle)

        #print('Epoch =', epoch, 'x_grad =', x_grad, 'y_grad =', y_grad, 'P =', err, 'Time =', total_time, 'Oracle =', total_oracle)
        # Stop once the accumulated optimisation time exceeds 2000 seconds.
        if total_time > 2000:
            break
        #print('Epoch =', epoch, 'P =', err, 'Time =', total_time)

        start_time = time.time()

# Persist the recorded histories. The file name encodes the hyper-parameters
# of the run, joined by underscores, followed by a fixed experiment suffix.
_name_fields = (lam, feature_size, M, cubic_epoch, agd_maxiter, cauchypoint)
filename = '_'.join(str(field) for field in _name_fields) + '_imcn_mnist_mnist_m.npz'
np.savez(
    filename,
    P=P_result,
    Time=time_result,
    xgrad=x_grad_result,
    ygrad=y_grad_result,
    oracle=oracle_result,
)
